"""
/** Copyright 2020 Zhejiang Lab and Zhejiang University. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
"""
from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.utils import IntegrityError
from django.test import Client
from django.test import TestCase

from task.models import Algorithm, AlgorithmField
from user.models import User

class AlgModelTests(TestCase):
    """Model-level tests for ``AlgorithmField`` uniqueness per algorithm.

    The tests expect that a field name may appear at most once per
    algorithm, but may be reused across different algorithms.
    """

    @classmethod
    def setUpTestData(cls):
        """Create algorithms ``sbm`` and ``kd``; attach an ``lr`` field to ``sbm``."""
        cls.alg = Algorithm.objects.create(alg_name='sbm')
        cls.alg2 = Algorithm.objects.create(alg_name='kd')
        AlgorithmField.objects.create(alg=cls.alg, field_name='lr', field_value='')

    def testUniqueSet(self):
        """A second ``lr`` field on the same algorithm must raise IntegrityError."""
        # Run the failing insert in its own savepoint. The IntegrityError then
        # only rolls back that savepoint, not the test's enclosing transaction,
        # so later queries in this test would still work.
        with self.assertRaises(IntegrityError):
            with transaction.atomic():
                AlgorithmField.objects.create(alg=self.alg, field_name='lr')

    def testUniqueSet2(self):
        """The same field name is accepted on a different algorithm."""
        # Any unexpected exception propagates and is reported with its real
        # traceback, instead of being masked as a bare assertion failure.
        AlgorithmField.objects.create(alg=self.alg2, field_name='lr')
        self.assertTrue(
            AlgorithmField.objects.filter(alg=self.alg2, field_name='lr').exists()
        )

class AlgAPITests(TestCase):
    """API tests for ``/api/task/algorithm_info`` and ``/api/task/reorg_task_options``."""

    @classmethod
    def setUpTestData(cls):
        """Create an admin user and two algorithms with their fields.

        ``sbm`` gets fields ``alpha`` and ``beta``; ``kd`` gets field ``lr``.
        """
        cls.user = User.objects.create_superuser(
            username='testadmin', email='test@zju.edu.cn', password='testpassword'
        )
        cls.alg = Algorithm.objects.create(alg_name='sbm')
        cls.alg2 = Algorithm.objects.create(alg_name='kd')
        AlgorithmField.objects.create(alg=cls.alg, field_name='alpha', field_value='0.5', field_note='alphanote')
        AlgorithmField.objects.create(alg=cls.alg, field_name='beta', field_value='0.1', field_note='betanote')
        AlgorithmField.objects.create(alg=cls.alg2, field_name='lr', field_value='0.1', field_note='lrnote')

    def setUp(self):
        # Authenticate the per-test client that TestCase provides. A single
        # Client stored on the class would shadow it and share cookies across
        # tests. force_login also cannot fail silently the way an unchecked
        # login() call can.
        self.client.force_login(self.user)

    def testPostAlgInfo(self):
        """POST creates a new algorithm together with its fields."""
        payload = {
            'alg_name': 'cfl',
            'fields': [
                {'field_name': 'alpha', 'field_value': '0.5', 'field_note': 'alphanote'},
                {'field_name': 'beta', 'field_value': '0.1', 'field_note': 'betanote'},
            ],
        }
        response = self.client.post('/api/task/algorithm_info', payload, content_type='application/json')
        self.assertEqual(response.status_code, 201)
        # Order explicitly so the comparison does not depend on the database's
        # default row order.
        fields = list(
            AlgorithmField.objects.filter(alg__alg_name='cfl')
            .order_by('field_name')
            .values('field_name', 'field_value', 'field_note')
        )
        self.assertEqual(fields, payload['fields'])

    def testPutGetAlgInfo(self):
        """PUT replaces the fields of ``sbm``; a subsequent GET reports the new set."""
        put_response = self.client.put('/api/task/algorithm_info', {
            'alg_name': 'sbm',
            'fields': [
                {'field_name': 'alpha', 'field_value': '0.9', 'field_note': 'alphanote'},
                {'field_name': 'gamma', 'field_value': '0.3', 'field_note': 'gammanote'},
            ],
        }, content_type='application/json')
        # The exact success code of the PUT view is not visible here, so accept
        # any 2xx rather than ignoring the response entirely.
        self.assertIn(put_response.status_code, range(200, 300))

        get_response = self.client.get('/api/task/algorithm_info')
        self.assertEqual(get_response.status_code, 200)
        fields = next(
            (alg['fields'] for alg in get_response.data if alg['alg_name'] == 'sbm'),
            None,
        )
        self.assertEqual(fields, [
            {'field_name': 'alpha', 'field_value': '0.9', 'field_note': 'alphanote'},
            {'field_name': 'gamma', 'field_value': '0.3', 'field_note': 'gammanote'},
        ])

    def testDeleteAlgInfo(self):
        """DELETE removes the algorithm and its fields, leaving ``kd``/``lr`` intact."""
        response = self.client.delete('/api/task/algorithm_info?alg_ids[]={}'.format(self.alg.id))
        self.assertEqual(response.status_code, 200)
        algs = Algorithm.objects.all().values_list('alg_name', flat=True)
        self.assertEqual(list(algs), ['kd'])
        fnames = AlgorithmField.objects.all().values_list('field_name', flat=True)
        self.assertEqual(list(fnames), ['lr'])

    def testGetReorgTaskOptions(self):
        """GET reorg_task_options lists algorithms (with fields) and the fixed task set."""
        response = self.client.get('/api/task/reorg_task_options')
        self.assertEqual(response.status_code, 200)
        # Use the real primary keys. Hard-coded 1/2 break whenever the
        # backend's id sequence is not reset between tests (e.g. PostgreSQL).
        target = {
            'datasets': [],
            'models': [],
            'algorithms': [
                {'id': self.alg2.id, 'alg_name': 'kd', 'fields': [
                    {'field_name': 'lr', 'field_value': '0.1', 'field_note': 'lrnote'},
                ]},
                {'id': self.alg.id, 'alg_name': 'sbm', 'fields': [
                    {'field_name': 'alpha', 'field_value': '0.5', 'field_note': 'alphanote'},
                    {'field_name': 'beta', 'field_value': '0.1', 'field_note': 'betanote'},
                ]},
            ],
            'tasks': [
                {'id': 'segmentation', 'name': 'Segmentation'},
                {'id': 'classification', 'name': 'Classification'},
                {'id': 'detection', 'name': 'Detection'},
                {'id': 'depth', 'name': 'Depth'},
                {'id': 'keypoints', 'name': 'Keypoints'},
            ],
        }
        self.assertEqual(target, response.data)